#!/usr/bin/env python3
import argparse
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
import aiohttp  

class SimpleASDivSolver:
    """Simplified ASDiv solver with direct IO mode.

    Sends each question to a chat-completions style HTTP endpoint, pulls the
    final numeric answer out of the reply text and keeps running accuracy
    statistics for the problems that were graded.
    """

    # Payload of a LaTeX \boxed{...} expression that contains no nested braces.
    _BOXED_PATTERN = r'\\boxed\{([^{}]+)\}'
    # Rest of the line following a "Final Answer:" marker.
    _FINAL_ANSWER_PATTERN = r'Final\s+Answer\s*:\s*([^\n]+)'
    # Optionally signed integer or decimal number ("8", "-4", "3.5", ".5").
    _NUMBER_PATTERN = r'-?\d*\.?\d+'
    # A comma that sits between a digit and a group of exactly three digits.
    _THOUSANDS_SEPARATOR_PATTERN = r'(?<=\d),(?=\d{3}(?!\d))'

    def __init__(self):
        # Placeholders: set these to the real model name and API base URL.
        self.model = "your model"
        self.base_url = "your base_url"
        # [estimated input tokens, estimated output tokens]
        self.token_counts = [0, 0]
        self.stats = {
            "total_problems": 0,
            "correct_answers": 0,
            "incorrect_answers": 0,
            "accuracy": 0.0
        }

    async def generate(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Token usage is not read from the server response; it is estimated as
        one token per four characters and added to ``self.token_counts``.

        Args:
            prompt: Full prompt text sent as the only chat message.

        Returns:
            The ``content`` of the first choice in the response.

        Raises:
            Exception: Any transport, HTTP status or response-shape error is
                printed and then re-raised unchanged.
        """
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 8000,
            "top_p": 0.8
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=120)
                ) as response:
                    # Fail on 4xx/5xx right here; otherwise an error body only
                    # surfaces later as an opaque KeyError on "choices".
                    response.raise_for_status()
                    resp = await response.json()

            content = resp["choices"][0]["message"]["content"]
            # Rough estimate: about four characters per token.
            self.token_counts[0] += len(prompt) // 4
            self.token_counts[1] += len(content) // 4
            return content
        except Exception as e:
            print(f"LLM Error: {str(e)}")
            raise

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract the final numeric answer from a model response.

        The answer span is the last LaTeX ``boxed{...}`` expression if one
        exists, otherwise the text after the last "Final Answer:" marker.
        The first number found in that span is returned with its sign and
        decimal part intact and with thousands-separator commas removed, so
        "3.5" stays "3.5" and "$1,234" becomes "1234".

        Args:
            text: Raw response text.

        Returns:
            The number as a string, or None when there is no answer span or
            the span contains no number.
        """
        boxed_matches = re.findall(self._BOXED_PATTERN, text)
        if boxed_matches:
            raw_answer = boxed_matches[-1]
        else:
            # Use the last marker: a reply may quote the requested
            # "Final Answer: [your answer]" format before the real answer.
            final_matches = re.findall(
                self._FINAL_ANSWER_PATTERN,
                text,
                re.IGNORECASE
            )
            if not final_matches:
                return None
            raw_answer = final_matches[-1].strip()

        # Join digit groups first so "1,234" is read as one number.
        joined = re.sub(self._THOUSANDS_SEPARATOR_PATTERN, '', raw_answer)
        number_match = re.search(self._NUMBER_PATTERN, joined)
        return number_match.group(0) if number_match else None

    async def solve_problem(self, question: str) -> Dict[str, Any]:
        """Directly solve a math problem.

        Args:
            question: Problem statement to embed in the prompt.

        Returns:
            A dict with the raw ``response``, the extracted ``answer`` (or
            None) and a copy of the current ``tokens`` counters.
        """
        prompt = f"""
Problem: {question}
Let's think step by step, provide the final answer in the format "Final Answer: [your answer]".
"""

        response = await self.generate(prompt)
        answer = self._extract_answer(response)

        return {
            "response": response,
            "answer": answer,
            "tokens": self.token_counts.copy()
        }

    async def load_problems(self, dataset_path: str, start_idx: int, end_idx: int) -> List[Dict]:
        """Load math problems from dataset.

        Args:
            dataset_path: Path to a JSON file; the parsed content is sliced,
                so it is expected to be a list.
            start_idx: First index of the slice (inclusive).
            end_idx: End index of the slice (exclusive).

        Returns:
            The selected slice, or an empty list if loading fails; the error
            is printed rather than raised.
        """
        try:
            with open(dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
                return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    def _extract_correct_answer(self, solution: str) -> Optional[str]:
        """Extract correct answer from solution with multiple pattern support.

        Patterns are tried in order and the first one that matches wins.

        Args:
            solution: Reference solution text.

        Returns:
            The stripped answer text, or None when *solution* is empty or no
            pattern matches.
        """
        if not solution:
            return None

        # Pattern 1: #### answer (GSM8K original format)
        hash_matches = re.findall(r'####\s*([^\n]+)', solution)
        if hash_matches:
            return hash_matches[-1].strip()

        # Pattern 2: \boxed{answer} (common LaTeX format)
        boxed_matches = re.findall(self._BOXED_PATTERN, solution)
        if boxed_matches:
            return boxed_matches[-1].strip()

        # Pattern 3: Final Answer: answer (explicit declaration)
        final_match = re.search(self._FINAL_ANSWER_PATTERN, solution, re.IGNORECASE)
        if final_match:
            return final_match.group(1).strip()

        # Pattern 4: The answer is [answer]
        answer_is_match = re.search(
            r'The\s+answer\s+is\s+([^\n]+)', solution, re.IGNORECASE
        )
        if answer_is_match:
            return answer_is_match.group(1).strip()

        return None

    def update_stats(self, is_correct: bool):
        """Record one graded problem and recompute the accuracy percentage.

        Args:
            is_correct: Whether the graded answer matched the reference.
        """
        self.stats["total_problems"] += 1
        if is_correct:
            self.stats["correct_answers"] += 1
        else:
            self.stats["incorrect_answers"] += 1

        # total_problems is at least 1 here, so the division is always safe.
        self.stats["accuracy"] = (
            self.stats["correct_answers"] / self.stats["total_problems"] * 100
        )

def _save_results(results: List[Dict[str, Any]], stats: Dict[str, Any], start: int, end: int) -> None:
    """Write records and statistics to a timestamped JSON file and print a summary."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"log/asdiv_cot/results_{start}_{end}_{timestamp}_acc{stats['accuracy']:.2f}%.json"

    output = {
        "results": results,
        "statistics": stats
    }

    with open(filename, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*50}\nFinal Statistics\n{'='*50}")
    print(f"Results saved to {filename}")
    print(f"Total problems processed: {stats['total_problems']}")
    print(f"Correct answers: {stats['correct_answers']}")
    print(f"Incorrect answers: {stats['incorrect_answers']}")
    print(f"Overall accuracy: {stats['accuracy']:.2f}%")
    print(f"{'='*50}\n")


async def main():
    """Solve a slice of the dataset, grade each answer and save the results.

    Command-line options select the dataset path and the [start, end) index
    range. Every problem that has a reference label is counted in the
    statistics; a reply from which no answer could be extracted counts as
    incorrect. Results gathered so far are written to ``log/asdiv_cot`` even
    if the run is interrupted by an error, which is then re-raised.
    """
    parser = argparse.ArgumentParser(description="Simple ASDiv Solver")
    parser.add_argument("--start", type=int, default=0, help="Start index in dataset")
    parser.add_argument("--end", type=int, default=1, help="End index in dataset")
    parser.add_argument("--dataset", type=str, default="ASDiv.json", help="Path to dataset")
    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs("log/asdiv_cot", exist_ok=True)

    solver = SimpleASDivSolver()
    problems = await solver.load_problems(args.dataset, args.start, args.end)
    results = []

    try:
        for idx, problem in enumerate(problems, args.start):
            # Get question from 'text' field and remove 'Answer:' part if present
            question = problem.get("text", "").split("Answer:")[0].strip()
            if not question:
                print(f"\n{'='*50}\nSkipping problem {idx}: No valid question found\n{'='*50}")
                continue

            print(f"\n{'='*50}\nProcessing problem {idx}: {question[:50]}...\n{'='*50}")

            # Reset token counts for each problem
            solver.token_counts = [0, 0]

            result = await solver.solve_problem(question)

            # A missing label stays empty instead of being graded as "0",
            # which would mark a model answer of 0 as correct.
            label = problem.get("label")
            correct_answer = "" if label is None else str(label).strip()

            answer = result["answer"]
            is_correct = (
                bool(correct_answer)
                and bool(answer)
                and str(answer).strip() == correct_answer
            )

            # Count every problem that has a reference answer. Leaving out
            # replies without an extractable answer would inflate accuracy.
            if correct_answer:
                solver.update_stats(is_correct)

            # Prepare result record
            record = {
                "problem_id": idx,
                "question": question,
                "response": result["response"],
                "answer": answer,
                "correct_answer": correct_answer,
                "is_correct": is_correct,
                "tokens": result["tokens"]
            }
            results.append(record)

            print(f"\nExecution Summary:")
            print(f"Answer: {answer}")
            print(f"Correct answer: {correct_answer}")
            print(f"Verification: {'CORRECT' if is_correct else 'INCORRECT'}")
            print(f"Tokens used: {result['tokens']}")
    finally:
        # Save whatever was collected, even when a request failed part-way
        # through the run; the original exception still propagates.
        if results:
            _save_results(results, solver.stats, args.start, args.end)

# Run the async entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())